# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

"""
DESCRIPTION:
    This sample demonstrates how to stream agent run events while using the file search
    tool from the Azure Agents service with a synchronous client.

USAGE:
    python sample_agents_stream_file_search.py

    Before running the sample:

    pip install azure-ai-projects azure-ai-agents azure-identity

    Set these environment variables with your own values:
    1) PROJECT_ENDPOINT - The Azure AI Project endpoint, as found in the Overview
                          page of your Azure AI Foundry portal.
    2) MODEL_DEPLOYMENT_NAME - The deployment name of the AI model, as found under the "Name" column in
       the "Models + endpoints" tab in your Azure AI Foundry project.
"""

import os
from azure.ai.projects import AIProjectClient
from azure.ai.agents.models import (
    AgentStreamEvent,
    FilePurpose,
    FileSearchTool,
    ListSortOrder,
    MessageDeltaChunk,
    MessageDeltaTextContent,
    MessageDeltaTextFileCitationAnnotation,
    MessageDeltaTextFileCitationAnnotationObject,
    RunAdditionalFieldList,
    RunStep,
    RunStepDeltaChunk,
    RunStepFileSearchToolCall,
    RunStepToolCallDetails,
    ThreadMessage,
    ThreadRun,
)
from azure.identity import DefaultAzureCredential

# Absolute path of the Markdown asset uploaded for file search, resolved relative
# to this script so the sample works regardless of the current working directory.
_sample_dir = os.path.dirname(__file__)
asset_file_path = os.path.abspath(os.path.join(_sample_dir, "../assets/product_info_1.md"))

# Project client for the endpoint given by PROJECT_ENDPOINT, authenticated with
# DefaultAzureCredential.
_project_endpoint = os.environ["PROJECT_ENDPOINT"]
project_client = AIProjectClient(endpoint=_project_endpoint, credential=DefaultAzureCredential())

with project_client:
    agents_client = project_client.agents

    # Display name used for citations that point at the uploaded sample file.
    file_name = os.path.split(asset_file_path)[-1]

    # Service-side resources created below. They start as None so that the cleanup
    # in the `finally` clause only deletes what was actually created.
    file = None
    vector_store = None
    agent = None

    try:
        # Upload file and create vector store
        file = agents_client.files.upload_and_poll(file_path=asset_file_path, purpose=FilePurpose.AGENTS)
        print(f"Uploaded file, file ID: {file.id}")

        vector_store = agents_client.vector_stores.create_and_poll(file_ids=[file.id], name="my_vectorstore")
        print(f"Created vector store, vector store ID: {vector_store.id}")

        # Create file search tool with resources followed by creating agent
        file_search = FileSearchTool(vector_store_ids=[vector_store.id])

        agent = agents_client.create_agent(
            model=os.environ["MODEL_DEPLOYMENT_NAME"],
            name="my-agent",
            instructions="Hello, you are helpful agent and can search information from uploaded files",
            tools=file_search.definitions,
            tool_resources=file_search.resources,
        )

        print(f"Created agent, ID: {agent.id}")

        # Create thread for communication
        thread = agents_client.threads.create()
        print(f"Created thread, ID: {thread.id}")

        # Create message to thread
        message = agents_client.messages.create(
            thread_id=thread.id, role="user", content="Hello, what Contoso products do you know?"
        )
        print(f"Created message, ID: {message.id}")

        # Maps the annotation placeholder text seen in the stream to the citation label
        # that should replace it in the printed text deltas.
        references = {}

        # FILE_SEARCH_CONTENTS is requested so that the RunStep events handled below
        # can print the content the file search tool found.
        with agents_client.runs.stream(
            thread_id=thread.id, agent_id=agent.id, include=[RunAdditionalFieldList.FILE_SEARCH_CONTENTS]
        ) as stream:

            for event_type, event_data, _ in stream:
                if isinstance(event_data, MessageDeltaChunk):
                    text = event_data.text
                    delta_content = event_data.delta.content
                    if delta_content and isinstance(delta_content[0], MessageDeltaTextContent):
                        delta_text_content = delta_content[0]
                        if delta_text_content.text and delta_text_content.text.annotations:
                            for delta_annotation in delta_text_content.text.annotations:
                                # Only file citations that carry both a file ID and a
                                # placeholder text can be turned into a reference.
                                if (
                                    isinstance(delta_annotation, MessageDeltaTextFileCitationAnnotation)
                                    and isinstance(
                                        delta_annotation.file_citation, MessageDeltaTextFileCitationAnnotationObject
                                    )
                                    and delta_annotation.file_citation.file_id
                                    and delta_annotation.text
                                ):
                                    cited_file_id = delta_annotation.file_citation.file_id
                                    # Show the local file name for the file uploaded by this
                                    # sample, and the raw file ID for anything else.
                                    citation = file_name if cited_file_id == file.id else cited_file_id
                                    references[delta_annotation.text] = citation
                                    print(f"File citation delta received: [{citation}]")
                    # Substitute every placeholder collected so far with its citation label.
                    for ref, citation in references.items():
                        text = text.replace(ref, f" [{citation}]")
                    print(f"Text delta received: {text}")

                elif isinstance(event_data, RunStepDeltaChunk):
                    print(f"RunStepDeltaChunk received. ID: {event_data.id}.")

                elif isinstance(event_data, ThreadMessage):
                    print(f"ThreadMessage created. ID: {event_data.id}, Status: {event_data.status}")

                elif isinstance(event_data, ThreadRun):
                    print(f"ThreadRun status: {event_data.status}")

                    if event_data.status == "failed":
                        print(f"Run failed. Error: {event_data.last_error}")

                elif isinstance(event_data, RunStep):
                    print(f"RunStep type: {event_data.type}, Status: {event_data.status}")
                    if isinstance(event_data.step_details, RunStepToolCallDetails):
                        for tool_call in event_data.step_details.tool_calls:
                            if (
                                isinstance(tool_call, RunStepFileSearchToolCall)
                                and tool_call.file_search
                                and tool_call.file_search.results
                                and tool_call.file_search.results[0].content
                                and tool_call.file_search.results[0].content[0].text
                            ):
                                print(
                                    "The search tool has found the next relevant content in "
                                    f"the file {tool_call.file_search.results[0].file_name}:"
                                )
                                # Note: technically we may have several search results, however in our example
                                # we only have one file, so we are taking the only result.
                                print(tool_call.file_search.results[0].content[0].text)
                                print("===============================================================")

                elif event_type == AgentStreamEvent.ERROR:
                    print(f"An error occurred. Data: {event_data}")

                elif event_type == AgentStreamEvent.DONE:
                    print("Stream completed.")

                else:
                    print(f"Unhandled Event Type: {event_type}, Data: {event_data}")
    finally:
        # Always clean up whatever was created, even if agent creation or streaming
        # raised, so that no vector store, file or agent is left behind on the service.
        if vector_store is not None:
            agents_client.vector_stores.delete(vector_store.id)
            print("Deleted vector store")

        if file is not None:
            agents_client.files.delete(file_id=file.id)
            print("Deleted file")

        if agent is not None:
            agents_client.delete_agent(agent.id)
            print("Deleted agent")

    # Fetch and log all messages
    messages = agents_client.messages.list(thread_id=thread.id, order=ListSortOrder.ASCENDING)

    # Print the last text part of every message in the thread, with the citation
    # placeholders replaced by readable labels.
    for msg in messages:
        if msg.text_messages:
            last_text = msg.text_messages[-1].text.value
            for annotation in msg.text_messages[-1].text.annotations:
                # Annotations without a file citation (or without a file ID) cannot be
                # mapped to a file, so their placeholder text is left untouched.
                file_citation = getattr(annotation, "file_citation", None)
                if file_citation is None or not file_citation.file_id:
                    continue
                citation = file_name if file_citation.file_id == file.id else file_citation.file_id
                last_text = last_text.replace(annotation.text, f" [{citation}]")
            print(f"{msg.role}: {last_text}")
